{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## lab1 Mnist分类任务"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入库，下载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.6.0+cu124\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Record which PyTorch build produced the outputs in this notebook.\n",
    "TORCH_VERSION = torch.__version__\n",
    "print(TORCH_VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import requests\n",
    "\n",
    "DATA_PATH = Path(\"data\")\n",
    "PATH = DATA_PATH / \"mnist\"\n",
    "\n",
    "PATH.mkdir(parents=True, exist_ok=True)\n",
    "# deeplearning.net no longer serves this archive; the PyTorch tutorials repository mirrors it.\n",
    "URL = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
    "FILENAME = \"mnist.pkl.gz\"\n",
    "\n",
    "if not (PATH / FILENAME).exists():\n",
    "    response = requests.get(URL + FILENAME, timeout=60)\n",
    "    # Fail loudly instead of saving an HTML error page under the .gz name.\n",
    "    response.raise_for_status()\n",
    "    # write_bytes opens, writes and closes the file in one call (no dangling handle).\n",
    "    (PATH / FILENAME).write_bytes(response.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gzip\n",
    "import pickle\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code stored in the file.\n",
    "# Only load an archive that comes from a source you trust.\n",
    "# encoding=\"latin-1\" lets Python 3 read the pickled arrays (presumably written\n",
    "# by Python 2 -- confirm if loading ever fails).\n",
    "archive_path = (PATH / FILENAME).as_posix()\n",
    "with gzip.open(archive_path, \"rb\") as archive:\n",
    "    # The pickle holds three splits; the third one is not used in this notebook.\n",
    "    train_split, valid_split, _ = pickle.load(archive, encoding=\"latin-1\")\n",
    "\n",
    "x_train, y_train = train_split\n",
    "x_valid, y_valid = valid_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 展示数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 784)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGaxJREFUeJzt3X+QVWX9B/Bn/cGKCksrwrICCqhYIjgZEKmkiSCVI0iNms1gOToYOCqJDU6KVramaQ5Fyh8NZCn+mAlNpqEUZJkScECJcSzGZSgwAZPa5ZeAwvnOOczul1WQzrLLc/fe12vmmcu993z2Hs6ePe/7nPPc55YlSZIEADjCjjrSLwgAKQEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARDFMaHA7N27N7zzzjuhU6dOoaysLPbqAJBTOr/B1q1bQ3V1dTjqqKPaTwCl4dOrV6/YqwHAYVq/fn3o2bNn+zkFl/Z8AGj/DnU8b7MAmjFjRjjttNPCcccdF4YOHRpeffXV/6nOaTeA4nCo43mbBNDTTz8dJk+eHKZNmxZee+21MGjQoDBq1Kjw7rvvtsXLAdAeJW1gyJAhycSJE5vu79mzJ6murk5qamoOWdvQ0JDOzq1pmqaF9t3S4/knafUe0O7du8OKFSvCiBEjmh5LR0Gk95csWfKx5Xft2hW2bNnSrAFQ/Fo9gN57772wZ8+e0L1792aPp/c3btz4seVrampCRUVFUzMCDqA0RB8FN3Xq1NDQ0NDU0mF7ABS/Vv8cUNeuXcPRRx8dNm3a1Ozx9H5VVdXHli8vL88aAKWl1XtAHTp0COedd15YsGBBs9kN0vvDhg1r7ZcDoJ1qk5kQ0iHY48ePD5/73OfCkCFDwiOPPBK2b98evvWtb7XFywHQDrVJAF111VXh3//+d7j77ruzgQfnnntumD9//scGJgBQusrSsdihgKTDsNPRcAC0b+nAss6dOxfuKDgASpMAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCiOifOyUJiOPvro3DUVFRWhUE2aNKlFdccff3zumv79++eumThxYu6an/70p7lrrrnmmtASO3fuzF1z//3356659957QynSAwIgCgEEQHEE0D333BPKysqatbPOOqu1XwaAdq5NrgGdffbZ4aWXXvr/FznGpSYAmmuTZEgDp6qqqi1+NABFok2uAb311luhuro69O3bN1x77bVh3bp1B112165dYcuWLc0aAMWv1QNo6NChYfbs2WH+/Pnh0UcfDWvXrg0XXnhh2Lp16wGXr6mpyYaxNrZevXq19ioBUAoBNHr06PD1r389DBw4MIwaNSr84Q9/CPX19eGZZ5454PJTp04NDQ0NTW39+vWtvUoAFKA2Hx3QpUuXcOaZZ4a6uroDPl9eXp41AEpLm38OaNu2bWHNmjWhR48ebf1SAJRyAN1+++2htrY2/OMf/wivvPJKGDt2bDa9SUunwgCgOLX6Kbi33347C5vNmzeHk08+OVxwwQVh6dKl2b8BoM0C6KmnnmrtH0mB6t27d+6aDh065K75whe+kLsmfePT0muWeY0bN65Fr1Vs0jefeU2fPj13TXpWJa+DjcI9lL/+9a+5a9IzQPxvzAUHQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIoS5IkCQVky5Yt2Vdzc+Sce+65LapbuHBh7hq/2/Zh7969
uWu+/e1vt+j7wo6EDRs2tKjuv//9b+6a1atXt+i1ilH6LdedO3c+6PN6QABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBTHxHlZCsm6detaVLd58+bcNWbD3mfZsmW5a+rr63PXXHzxxaEldu/enbvmN7/5TYtei9KlBwRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAojAZKeE///lPi+qmTJmSu+arX/1q7prXX389d8306dPDkbJy5crcNZdeemnumu3bt+euOfvss0NL3HLLLS2qgzz0gACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFGVJkiShgGzZsiVUVFTEXg3aSOfOnXPXbN26NXfNzJkzQ0tcf/31uWu++c1v5q6ZM2dO7hpobxoaGj7xb14PCIAoBBAA7SOAFi9eHC6//PJQXV0dysrKwnPPPdfs+fSM3t133x169OgROnbsGEaMGBHeeuut1lxnAEoxgNIvxRo0aFCYMWPGAZ9/4IEHsi8De+yxx8KyZcvCCSecEEaNGhV27tzZGusLQKl+I+ro0aOzdiBp7+eRRx4J3//+98MVV1yRPfb444+H7t27Zz2lq6+++vDXGICi0KrXgNauXRs2btyYnXZrlI5oGzp0aFiyZMkBa3bt2pWNfNu/AVD8WjWA0vBJpT2e/aX3G5/7qJqamiykGluvXr1ac5UAKFDRR8FNnTo1Gyve2NavXx97lQBobwFUVVWV3W7atKnZ4+n9xuc+qry8PPug0v4NgOLXqgHUp0+fLGgWLFjQ9Fh6TScdDTds2LDWfCkASm0U3LZt20JdXV2zgQcrV64MlZWVoXfv3uHWW28NP/rRj8IZZ5yRBdJdd92VfWZozJgxrb3uAJRSAC1fvjxcfPHFTfcnT56c3Y4fPz7Mnj073HHHHdlnhW688cZQX18fLrjggjB//vxw3HHHte6aA9CumYyUovTggw+2qK7xDVUetbW1uWv2/6jC/2rv3r25ayAmk5ECUJAEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIwmzYFKUTTjihRXUvvPBC7povfvGLuWtGjx6du+ZPf/pT7hqIyWzYABQkAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRmIwU9tOvX7/cNa+99lrumvr6+tw1L7/8cu6a5cuXh5aYMWNG7poCO5RQAExGCkBBEkAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhclI4TCNHTs2d82sWbNy13Tq1CkcKXfeeWfumscffzx3zYYNG3LX0H6YjBSAgiSAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAqTkUIEAwYMyF3z8MMP56655JJLwpEyc+bM3DX33Xdf7pp//etfuWuIw2SkABQkAQRA+wigxYsXh8svvzxUV1eHsrKy8NxzzzV7/rrrrsse379ddtllrbnOAJRiAG3fvj0MGjQozJgx46DLpIGTftFUY5szZ87hricAReaYvAWjR4/O2icpLy8PVVVVh7NeABS5NrkGtGjRotCtW7fQv3//cNNNN4XNmzcfdNldu3ZlI9/2bwAUv1YPoPT0W/rd8AsWLAg/+clPQm1tbdZj2rNnzwGXr6mpyYZdN7ZevXq19ioBUAyn4A7l6quvbvr3OeecEwYOHBj69euX9YoO9JmEqVOnhsmTJzfdT3tAQgig+LX5MOy+ffuGrl27hrq6uoNeL0o/qLR/A6D4tXkAvf3229k1oB49erT1SwFQzKfgtm3b1qw3s3bt2rBy5cpQWVmZtXvvvTeMGzcuGwW3Zs2acMcd
d4TTTz89jBo1qrXXHYBSCqDly5eHiy++uOl+4/Wb8ePHh0cffTSsWrUq/PrXvw719fXZh1VHjhwZfvjDH2an2gCgkclIoZ3o0qVL7pp01pKWmDVrVu6adNaTvBYuXJi75tJLL81dQxwmIwWgIAkgAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCF2bCBj9m1a1fummOOyf3tLuHDDz/MXdOS7xZbtGhR7hoOn9mwAShIAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiyD97IHDYBg4cmLvma1/7Wu6awYMHh5ZoycSiLfHmm2/mrlm8eHGbrAtHnh4QAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCZKSwn/79++eumTRpUu6aK6+8MndNVVVVKGR79uzJXbNhw4bcNXv37s1dQ2HSAwIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUZiMlILXkkk4r7nmmha9VksmFj3ttNNCsVm+fHnumvvuuy93ze9///vcNRQPPSAAohBAABR+ANXU1ITBgweHTp06hW7duoUxY8aE1atXN1tm586dYeLEieGkk04KJ554Yhg3blzYtGlTa683AKUUQLW1tVm4LF26NLz44ovhgw8+CCNHjgzbt29vWua2224LL7zwQnj22Wez5d95550WffkWAMUt1yCE+fPnN7s/e/bsrCe0YsWKMHz48NDQ0BB+9atfhSeffDJ86UtfypaZNWtW+PSnP52F1uc///nWXXsASvMaUBo4qcrKyuw2DaK0VzRixIimZc4666zQu3fvsGTJkgP+jF27doUtW7Y0awAUvxYHUPq97Lfeems4//zzw4ABA7LHNm7cGDp06BC6dOnSbNnu3btnzx3sulJFRUVT69WrV0tXCYBSCKD0WtAbb7wRnnrqqcNagalTp2Y9qca2fv36w/p5ABTxB1HTD+vNmzcvLF68OPTs2bPZBwZ3794d6uvrm/WC0lFwB/swYXl5edYAKC25ekBJkmThM3fu3LBw4cLQp0+fZs+fd9554dhjjw0LFixoeiwdpr1u3bowbNiw1ltrAEqrB5SedktHuD3//PPZZ4Ear+uk1246duyY3V5//fVh8uTJ2cCEzp07h5tvvjkLHyPgAGhxAD366KPZ7UUXXdTs8XSo9XXXXZf9+2c/+1k46qijsg+gpiPcRo0aFX75y1/meRkASkBZkp5XKyDpMOy0J0XhS0c35vWZz3wmd80vfvGL3DXp8P9is2zZstw1Dz74YIteKz3L0ZKRsbC/dGBZeibsYMwFB0AUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAAtJ9vRKVwpd/DlNfMmTNb9Frnnntu7pq+ffuGYvPKK6/krnnooYdy1/zxj3/MXfP+++/nroEjRQ8IgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAERhMtIjZOjQoblrpkyZkrtmyJAhuWtOOeWUUGx27NjRorrp06fnrvnxj3+cu2b79u25a6DY6AEBEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgChMRnqEjB079ojUHElvvvlm7pp58+blrvnwww9z1zz00EOhJerr61tUB+SnBwRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAoihLkiQJBWTLli2hoqIi9moAcJgaGhpC586dD/q8HhAAUQggAAo/gGpqasLgwYNDp06dQrdu3cKYMWPC6tWrmy1z0UUXhbKysmZtwoQJrb3eAJRSANXW1oaJEyeGpUuXhhdf
fDF88MEHYeTIkWH79u3NlrvhhhvChg0bmtoDDzzQ2usNQCl9I+r8+fOb3Z89e3bWE1qxYkUYPnx40+PHH398qKqqar21BKDoHHW4IxxSlZWVzR5/4oknQteuXcOAAQPC1KlTw44dOw76M3bt2pWNfNu/AVACkhbas2dP8pWvfCU5//zzmz0+c+bMZP78+cmqVauS3/72t8kpp5ySjB079qA/Z9q0aekwcE3TNC0UV2toaPjEHGlxAE2YMCE59dRTk/Xr13/icgsWLMhWpK6u7oDP79y5M1vJxpb+vNgbTdM0TQttHkC5rgE1mjRpUpg3b15YvHhx6Nmz5ycuO3To0Oy2rq4u9OvX72PPl5eXZw2A0pIrgNIe08033xzmzp0bFi1aFPr06XPImpUrV2a3PXr0aPlaAlDaAZQOwX7yySfD888/n30WaOPGjdnj6dQ5HTt2DGvWrMme//KXvxxOOumksGrVqnDbbbdlI+QGDhzYVv8HANqjPNd9Dnaeb9asWdnz69atS4YPH55UVlYm5eXlyemnn55MmTLlkOcB95cuG/u8paZpmhYOux3q2G8yUgDahMlIAShIAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUBRdASZLEXgUAjsDxvOACaOvWrbFXAYAjcDwvSwqsy7F3797wzjvvhE6dOoWysrJmz23ZsiX06tUrrF+/PnTu3DmUKtthH9thH9thH9uhcLZDGitp+FRXV4ejjjp4P+eYUGDSle3Zs+cnLpNu1FLewRrZDvvYDvvYDvvYDoWxHSoqKg65TMGdggOgNAggAKJoVwFUXl4epk2blt2WMtthH9thH9thH9uh/W2HghuEAEBpaFc9IACKhwACIAoBBEAUAgiAKNpNAM2YMSOcdtpp4bjjjgtDhw4Nr776aig199xzTzY7xP7trLPOCsVu8eLF4fLLL88+VZ3+n5977rlmz6fjaO6+++7Qo0eP0LFjxzBixIjw1ltvhVLbDtddd93H9o/LLrssFJOampowePDgbKaUbt26hTFjxoTVq1c3W2bnzp1h4sSJ4aSTTgonnnhiGDduXNi0aVMote1w0UUXfWx/mDBhQigk7SKAnn766TB58uRsaOFrr70WBg0aFEaNGhXefffdUGrOPvvssGHDhqb25z//ORS77du3Z7/z9E3IgTzwwANh+vTp4bHHHgvLli0LJ5xwQrZ/pAeiUtoOqTRw9t8/5syZE4pJbW1tFi5Lly4NL774Yvjggw/CyJEjs23T6LbbbgsvvPBCePbZZ7Pl06m9rrzyylBq2yF1ww03NNsf0r+VgpK0A0OGDEkmTpzYdH/Pnj1JdXV1UlNTk5SSadOmJYMGDUpKWbrLzp07t+n+3r17k6qqquTBBx9seqy+vj4pLy9P5syZk5TKdkiNHz8+ueKKK5JS8u6772bbora2tul3f+yxxybPPvts0zJ/+9vfsmWWLFmSlMp2SH3xi19MbrnllqSQFXwPaPfu3WHFihXZaZX954tL7y9ZsiSUmvTUUnoKpm/fvuHaa68N69atC6Vs7dq1YePGjc32j3QOqvQ0bSnuH4sWLcpOyfTv3z/cdNNNYfPmzaGYNTQ0ZLeVlZXZbXqsSHsD++8P6Wnq3r17F/X+0PCR7dDoiSeeCF27dg0DBgwIU6dODTt27AiFpOAmI/2o9957L+zZsyd079692ePp/b///e+hlKQH1dmzZ2cHl7Q7fe+994YLL7wwvPHGG9m54FKUhk/qQPtH43OlIj39lp5q6tOnT1izZk248847w+jRo7MD79FHHx2KTTpz/q233hrOP//87ACbSn/nHTp0CF26dCmZ/WHvAbZD6hvf+EY49dRTszesq1atCt/73vey60S/+93vQqEo+ADi/6UHk0YDBw7M
AindwZ555plw/fXXR1034rv66qub/n3OOedk+0i/fv2yXtEll1wSik16DSR981UK10Fbsh1uvPHGZvtDOkgn3Q/SNyfpflEICv4UXNp9TN+9fXQUS3q/qqoqlLL0Xd6ZZ54Z6urqQqlq3AfsHx+XnqZN/36Kcf+YNGlSmDdvXnj55ZebfX1L+jtPT9vX19eXxP4w6SDb4UDSN6ypQtofCj6A0u70eeedFxYsWNCsy5neHzZsWChl27Zty97NpO9sSlV6uik9sOy/f6RfyJWOhiv1/ePtt9/OrgEV0/6Rjr9ID7pz584NCxcuzH7/+0uPFccee2yz/SE97ZReKy2m/SE5xHY4kJUrV2a3BbU/JO3AU089lY1qmj17dvLmm28mN954Y9KlS5dk48aNSSn57ne/myxatChZu3Zt8pe//CUZMWJE0rVr12wETDHbunVr8vrrr2ct3WUffvjh7N///Oc/s+fvv//+bH94/vnnk1WrVmUjwfr06ZO8//77Salsh/S522+/PRvple4fL730UvLZz342OeOMM5KdO3cmxeKmm25KKioqsr+DDRs2NLUdO3Y0LTNhwoSkd+/eycKFC5Ply5cnw4YNy1oxuekQ26Guri75wQ9+kP3/0/0h/dvo27dvMnz48KSQtIsASv385z/PdqoOHTpkw7KXLl2alJqrrroq6dGjR7YNTjnllOx+uqMVu5dffjk74H60pcOOG4di33XXXUn37t2zNyqXXHJJsnr16qSUtkN64Bk5cmRy8sknZ8OQTz311OSGG24oujdpB/r/p23WrFlNy6RvPL7zne8kn/rUp5Ljjz8+GTt2bHZwLqXtsG7duixsKisrs7+J008/PZkyZUrS0NCQFBJfxwBAFAV/DQiA4iSAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIMTwfwuo74MNPBzYAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Each row of x_train is one flattened image; reshape it back to 28x28 pixels to view it.\n",
    "first_image = x_train[0].reshape((28, 28))\n",
    "plt.imshow(first_image, cmap=\"gray\")\n",
    "print(x_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])\n",
      "torch.Size([50000, 784])\n",
      "tensor(0) tensor(9)\n"
     ]
    }
   ],
   "source": [
    "# Convert the numpy arrays to PyTorch tensors.\n",
    "# Run this cell once per kernel: calling torch.tensor on an existing tensor\n",
    "# only copies it and emits a warning.\n",
    "x_train, y_train, x_valid, y_valid = map(\n",
    "    torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
    ")\n",
    "n, c = x_train.shape  # number of training samples, values per flattened image\n",
    "print(x_train, y_train)\n",
    "print(x_train.shape)\n",
    "print(y_train.min(), y_train.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# F.cross_entropy takes raw logits and integer class labels.\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "\n",
    "def model(xb):\n",
    "    \"\"\"Linear baseline: return the logits xb.mm(weights) + bias.\n",
    "\n",
    "    Reads the global tensors `weights` and `bias`, which are created in the next cell.\n",
    "    Note: the name `model` is rebound later by `model, opt = get_model()`.\n",
    "    \"\"\"\n",
    "    logits = xb.mm(weights)\n",
    "    return logits + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(13.9973, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "bs = 64  # mini-batch size; also used by the DataLoaders further down\n",
    "xb = x_train[0:bs]  # a mini-batch from x\n",
    "yb = y_train[0:bs]\n",
    "\n",
    "# Scale the N(0, 1) init by 1/sqrt(fan_in) (Xavier-style). Unscaled weights give very\n",
    "# large logits and an inflated initial loss (about 14 here, versus roughly ln(10) ~ 2.3\n",
    "# for near-uniform predictions).\n",
    "weights = torch.randn([784, 10], dtype=torch.float32) / 784 ** 0.5\n",
    "# Enable gradients only after the division so `weights` remains a leaf tensor\n",
    "# whose .grad is populated by backward().\n",
    "weights.requires_grad_()\n",
    "bias = torch.zeros(10, requires_grad=True)\n",
    "\n",
    "print(loss_func(model(xb), yb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "\n",
    "class Mnist_NN(nn.Module):\n",
    "    \"\"\"Two-hidden-layer MLP for flattened 28x28 MNIST images.\n",
    "\n",
    "    784 -> 128 -> 256 -> 10, with ReLU then dropout (p=0.5) after each hidden layer.\n",
    "    Returns raw logits (no softmax), to be paired with F.cross_entropy.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Registration order fixes the parameter order and the printed layout.\n",
    "        self.hidden1 = nn.Linear(784, 128)\n",
    "        self.hidden2 = nn.Linear(128, 256)\n",
    "        self.out = nn.Linear(256, 10)\n",
    "        # A single Dropout module is shared by both hidden layers; it has no parameters.\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        for hidden_layer in (self.hidden1, self.hidden2):\n",
    "            x = self.dropout(F.relu(hidden_layer(x)))\n",
    "        return self.out(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mnist_NN(\n",
      "  (hidden1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (hidden2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the network and print its layer structure.\n",
    "net = Mnist_NN()\n",
    "print(repr(net))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden1.weight Parameter containing:\n",
      "tensor([[ 0.0159,  0.0185, -0.0132,  ...,  0.0144,  0.0066, -0.0144],\n",
      "        [ 0.0345,  0.0200,  0.0224,  ...,  0.0343, -0.0097, -0.0296],\n",
      "        [-0.0141, -0.0032, -0.0117,  ..., -0.0240,  0.0033,  0.0166],\n",
      "        ...,\n",
      "        [ 0.0185, -0.0206,  0.0299,  ...,  0.0044,  0.0331, -0.0279],\n",
      "        [-0.0247,  0.0065,  0.0143,  ...,  0.0115, -0.0126, -0.0357],\n",
      "        [-0.0176, -0.0194,  0.0189,  ..., -0.0195,  0.0283,  0.0005]],\n",
      "       requires_grad=True) torch.Size([128, 784])\n",
      "hidden1.bias Parameter containing:\n",
      "tensor([-0.0176, -0.0059, -0.0279,  0.0168, -0.0063,  0.0159, -0.0286, -0.0156,\n",
      "        -0.0168, -0.0342, -0.0194,  0.0163, -0.0010, -0.0192, -0.0223,  0.0093,\n",
      "        -0.0219,  0.0306, -0.0007, -0.0047, -0.0296,  0.0068, -0.0123, -0.0330,\n",
      "         0.0304,  0.0024,  0.0102,  0.0338,  0.0024,  0.0340, -0.0191,  0.0274,\n",
      "         0.0116,  0.0112, -0.0030, -0.0249, -0.0278, -0.0228,  0.0147, -0.0225,\n",
      "         0.0053,  0.0179, -0.0038,  0.0017,  0.0226, -0.0267,  0.0222, -0.0313,\n",
      "        -0.0167, -0.0145, -0.0029, -0.0255,  0.0118,  0.0209,  0.0284,  0.0245,\n",
      "         0.0167,  0.0352, -0.0145,  0.0080,  0.0305,  0.0177, -0.0300, -0.0258,\n",
      "        -0.0032, -0.0207, -0.0194, -0.0342,  0.0172, -0.0276,  0.0327, -0.0110,\n",
      "         0.0157, -0.0171,  0.0028,  0.0344,  0.0262, -0.0163, -0.0209,  0.0268,\n",
      "         0.0215,  0.0244, -0.0298, -0.0023, -0.0341,  0.0060,  0.0068, -0.0018,\n",
      "         0.0330, -0.0030, -0.0074, -0.0223,  0.0101,  0.0128, -0.0333,  0.0338,\n",
      "         0.0347,  0.0052,  0.0302, -0.0201, -0.0039, -0.0240, -0.0284, -0.0339,\n",
      "        -0.0110,  0.0305,  0.0205,  0.0040, -0.0284, -0.0208,  0.0023,  0.0118,\n",
      "        -0.0233, -0.0154,  0.0036, -0.0074, -0.0305,  0.0007,  0.0019, -0.0301,\n",
      "         0.0152,  0.0034,  0.0265,  0.0178, -0.0147,  0.0185,  0.0113, -0.0172],\n",
      "       requires_grad=True) torch.Size([128])\n",
      "hidden2.weight Parameter containing:\n",
      "tensor([[-0.0399,  0.0785,  0.0299,  ..., -0.0840, -0.0206,  0.0376],\n",
      "        [-0.0584, -0.0315, -0.0480,  ..., -0.0351,  0.0027, -0.0434],\n",
      "        [ 0.0520, -0.0327,  0.0551,  ..., -0.0165, -0.0593,  0.0246],\n",
      "        ...,\n",
      "        [-0.0044,  0.0139, -0.0457,  ...,  0.0741, -0.0391,  0.0680],\n",
      "        [-0.0569,  0.0259,  0.0642,  ...,  0.0386,  0.0801,  0.0686],\n",
      "        [ 0.0469, -0.0296,  0.0658,  ...,  0.0029, -0.0412, -0.0746]],\n",
      "       requires_grad=True) torch.Size([256, 128])\n",
      "hidden2.bias Parameter containing:\n",
      "tensor([-0.0535, -0.0024, -0.0180,  0.0867,  0.0803, -0.0845,  0.0038, -0.0506,\n",
      "        -0.0159, -0.0067,  0.0055, -0.0090,  0.0385, -0.0758, -0.0483,  0.0366,\n",
      "        -0.0208,  0.0346,  0.0722, -0.0590, -0.0654, -0.0139,  0.0346,  0.0795,\n",
      "        -0.0065,  0.0388,  0.0233, -0.0364, -0.0875,  0.0438,  0.0769, -0.0703,\n",
      "        -0.0198, -0.0223, -0.0539, -0.0489, -0.0338, -0.0849,  0.0141,  0.0259,\n",
      "        -0.0020, -0.0501,  0.0461,  0.0825, -0.0411,  0.0500,  0.0042, -0.0809,\n",
      "        -0.0881, -0.0316,  0.0182, -0.0837,  0.0881, -0.0583, -0.0288, -0.0031,\n",
      "         0.0668,  0.0823, -0.0149, -0.0772, -0.0254, -0.0157, -0.0054, -0.0573,\n",
      "         0.0546, -0.0594,  0.0277, -0.0413, -0.0823, -0.0444,  0.0048, -0.0175,\n",
      "         0.0578, -0.0408,  0.0818,  0.0016, -0.0103,  0.0559,  0.0741, -0.0536,\n",
      "         0.0305,  0.0814, -0.0830,  0.0295,  0.0659, -0.0760,  0.0842,  0.0878,\n",
      "         0.0075,  0.0159, -0.0805,  0.0253, -0.0450, -0.0612,  0.0299, -0.0788,\n",
      "         0.0760,  0.0441, -0.0318, -0.0093, -0.0310,  0.0640,  0.0043,  0.0602,\n",
      "         0.0587, -0.0662,  0.0017, -0.0181, -0.0426,  0.0451, -0.0834,  0.0007,\n",
      "        -0.0580,  0.0127, -0.0211,  0.0120,  0.0223,  0.0190,  0.0388,  0.0693,\n",
      "        -0.0102,  0.0180, -0.0730,  0.0116, -0.0243,  0.0411, -0.0803, -0.0256,\n",
      "        -0.0309,  0.0145,  0.0347,  0.0678,  0.0055, -0.0058,  0.0642,  0.0587,\n",
      "        -0.0402, -0.0870, -0.0172, -0.0685,  0.0296,  0.0808,  0.0141, -0.0108,\n",
      "         0.0451,  0.0746,  0.0167,  0.0291, -0.0136,  0.0730,  0.0756, -0.0372,\n",
      "        -0.0673, -0.0013,  0.0360, -0.0481, -0.0331,  0.0561, -0.0709,  0.0402,\n",
      "        -0.0766,  0.0198, -0.0502, -0.0150, -0.0685, -0.0047, -0.0180,  0.0751,\n",
      "        -0.0835, -0.0592,  0.0349,  0.0487,  0.0257,  0.0072, -0.0480, -0.0219,\n",
      "        -0.0063,  0.0537, -0.0644,  0.0635,  0.0268,  0.0253,  0.0165,  0.0612,\n",
      "         0.0668,  0.0264, -0.0448, -0.0558, -0.0099, -0.0323, -0.0051, -0.0337,\n",
      "        -0.0317,  0.0697,  0.0335,  0.0798,  0.0779, -0.0670, -0.0458,  0.0525,\n",
      "        -0.0557,  0.0465, -0.0116,  0.0205,  0.0873,  0.0694,  0.0197, -0.0520,\n",
      "         0.0290, -0.0220, -0.0524,  0.0800, -0.0693,  0.0414, -0.0498, -0.0540,\n",
      "         0.0343,  0.0672,  0.0724, -0.0506,  0.0290, -0.0513, -0.0096, -0.0787,\n",
      "        -0.0553,  0.0430,  0.0532, -0.0332, -0.0646,  0.0639, -0.0610,  0.0038,\n",
      "         0.0650,  0.0206, -0.0332,  0.0610,  0.0233,  0.0833, -0.0359,  0.0451,\n",
      "         0.0333, -0.0397,  0.0189,  0.0841,  0.0207,  0.0239,  0.0326, -0.0343,\n",
      "         0.0601,  0.0525,  0.0006,  0.0765,  0.0806, -0.0745, -0.0057, -0.0130],\n",
      "       requires_grad=True) torch.Size([256])\n",
      "out.weight Parameter containing:\n",
      "tensor([[-0.0328,  0.0419,  0.0003,  ..., -0.0156,  0.0574, -0.0395],\n",
      "        [ 0.0380,  0.0398,  0.0610,  ...,  0.0340, -0.0116,  0.0489],\n",
      "        [-0.0515,  0.0246,  0.0003,  ..., -0.0030, -0.0396,  0.0247],\n",
      "        ...,\n",
      "        [-0.0482,  0.0194, -0.0091,  ...,  0.0284,  0.0461, -0.0085],\n",
      "        [-0.0229, -0.0518, -0.0080,  ..., -0.0121, -0.0497,  0.0089],\n",
      "        [ 0.0353,  0.0330,  0.0253,  ...,  0.0510,  0.0583,  0.0154]],\n",
      "       requires_grad=True) torch.Size([10, 256])\n",
      "out.bias Parameter containing:\n",
      "tensor([-0.0528, -0.0482,  0.0245, -0.0562, -0.0329, -0.0072, -0.0345, -0.0243,\n",
      "        -0.0396, -0.0221], requires_grad=True) torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "# List every learnable tensor together with its shape.\n",
    "for param_name, param in net.named_parameters():\n",
    "    print(param_name, param, param.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用TensorDataset和DataLoader来简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "# Pair every image with its label.\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "valid_ds = TensorDataset(x_valid, y_valid)\n",
    "\n",
    "# Only the training data is shuffled; validation runs without backward passes,\n",
    "# so it uses twice the batch size.\n",
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=bs * 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(train_ds, valid_ds, bs):\n",
    "    \"\"\"Wrap the training and validation datasets in DataLoaders.\n",
    "\n",
    "    The training loader reshuffles every epoch. The validation loader keeps a fixed\n",
    "    order and uses twice the batch size (validation runs without backward passes).\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, valid_loader)\n",
    "    \"\"\"\n",
    "    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)\n",
    "    return train_loader, valid_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(steps, model, loss_func, opt, train_dl, valid_dl):\n",
    "    \"\"\"Train `model` for `steps` epochs, printing the validation loss after each one.\n",
    "\n",
    "    Args:\n",
    "        steps: number of full passes over `train_dl`.\n",
    "        model: the nn.Module being trained.\n",
    "        loss_func: callable (predictions, targets) -> scalar loss tensor.\n",
    "        opt: optimizer over the model's parameters.\n",
    "        train_dl: iterable of (xb, yb) training batches.\n",
    "        valid_dl: iterable of (xb, yb) validation batches.\n",
    "\n",
    "    Relies on `loss_batch` (defined in a later cell) and the global `torch` / `np`.\n",
    "    \"\"\"\n",
    "    for step in range(steps):\n",
    "        # Switch back to training mode every epoch: the validation pass below calls\n",
    "        # model.eval(), and without this every epoch after the first would train\n",
    "        # with dropout disabled.\n",
    "        model.train()\n",
    "        for xb, yb in train_dl:\n",
    "            loss_batch(model, loss_func, xb, yb, opt)\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            losses, nums = zip(\n",
    "                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]\n",
    "            )\n",
    "        # Size-weighted average of the per-batch losses (assumes loss_func returns a\n",
    "        # batch mean, as F.cross_entropy does by default), so a short last batch is not over-weighted.\n",
    "        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)\n",
    "        print('当前step: ' + str(step), '验证集损失: ' + str(val_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "\n",
    "def get_model(lr=0.001):\n",
    "    \"\"\"Create a fresh Mnist_NN and a plain SGD optimizer over its parameters.\n",
    "\n",
    "    Args:\n",
    "        lr: SGD learning rate (default 0.001).\n",
    "\n",
    "    Returns:\n",
    "        (model, optimizer)\n",
    "    \"\"\"\n",
    "    model = Mnist_NN()\n",
    "    return model, optim.SGD(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    \"\"\"Compute the loss on one batch and, when `opt` is given, take one optimizer step.\n",
    "\n",
    "    Returns:\n",
    "        (loss as a Python float, number of samples in the batch)\n",
    "    \"\"\"\n",
    "    loss = loss_func(model(xb), yb)\n",
    "    if opt is None:\n",
    "        # Evaluation only: no backward pass, no parameter update.\n",
    "        return loss.item(), len(xb)\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    opt.zero_grad()\n",
    "    return loss.item(), len(xb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前step: 0 验证集损失: 2.2851624969482423\n",
      "当前step: 1 验证集损失: 2.257579857635498\n",
      "当前step: 2 验证集损失: 2.2177301624298096\n",
      "当前step: 3 验证集损失: 2.1537818508148194\n",
      "当前step: 4 验证集损失: 2.0469440422058107\n",
      "当前step: 5 验证集损失: 1.8742363395690917\n",
      "当前step: 6 验证集损失: 1.6305477025985717\n",
      "当前step: 7 验证集损失: 1.3566381969451904\n",
      "当前step: 8 验证集损失: 1.1133916572570801\n",
      "当前step: 9 验证集损失: 0.9288397747039795\n",
      "当前step: 10 验证集损失: 0.7966901840209961\n",
      "当前step: 11 验证集损失: 0.7022734102249145\n",
      "当前step: 12 验证集损失: 0.6328500133514404\n",
      "当前step: 13 验证集损失: 0.5814859872817993\n",
      "当前step: 14 验证集损失: 0.5411100621700287\n",
      "当前step: 15 验证集损失: 0.5097664078235626\n",
      "当前step: 16 验证集损失: 0.4841191698074341\n",
      "当前step: 17 验证集损失: 0.46355148544311525\n",
      "当前step: 18 验证集损失: 0.4459900721549988\n",
      "当前step: 19 验证集损失: 0.4306721311092377\n"
     ]
    }
   ],
   "source": [
    "# Fix the seed so weight initialisation, dropout masks and batch shuffling are reproducible.\n",
    "torch.manual_seed(0)\n",
    "EPOCHS = 20\n",
    "\n",
    "train_dl, valid_dl = get_data(train_ds, valid_ds, bs)\n",
    "model, opt = get_model()\n",
    "fit(EPOCHS, model, loss_func, opt, train_dl, valid_dl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "准确率: 88.61%\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the validation set with dropout disabled and without building\n",
    "# autograd graphs.\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for xb, yb in valid_dl:\n",
    "        predicted = model(xb).argmax(dim=1)  # index of the largest logit per row\n",
    "        total += yb.size(0)\n",
    "        correct += (predicted == yb).sum().item()\n",
    "print('准确率: ' + str(correct / total * 100) + '%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测这是： 5\n",
      "实际上这是： 5\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGWtJREFUeJzt3X2MFdXhP+CzIKwo7NIVYVl5EcSXRgVTq5QqFJWCtCGi/uFbGmwMFgqmiC8NjYq2TVYxsUZKtU0aqSm+lKRoNIZGUSC0oAFLiG2lLGJBeVFJ2AUsqMv8MuOP/bIC2rnucu7e+zzJyd25M+fO7OzsfO6ZOffciiRJkgAAx1inY71CAEgJIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKI4LRebAgQNh69atoUePHqGioiL25gCQUzq+we7du0NdXV3o1KlTxwmgNHz69+8fezMA+Iq2bNkS+vXr13EuwaUtHwA6vi87n7dbAM2bNy+ceuqp4fjjjw/Dhw8Pr7/++v9Uz2U3gNLwZefzdgmgZ555JsycOTPMnj07vPHGG2HYsGFh3Lhx4f3332+P1QHQESXt4MILL0ymTZvWMt3c3JzU1dUl9fX1X1q3sbExHZ1bURRFCR27pOfzL9LmLaCPP/44rFmzJowZM6blubQXRDq9cuXKw5bfv39/aGpqalUAKH1tHkAffvhhaG5uDn369Gn1fDq9ffv2w5avr68P1dXVLUUPOIDyEL0X3KxZs0JjY2NLSbvtAVD62vxzQL169QqdO3cOO3bsaPV8Ol1bW3vY8pWVlVkBoLy0eQuoa9eu4fzzzw9LlixpNbpBOj1ixIi2Xh0AHVS7jISQdsGeNGlS+OY3vxkuvPDC8PDDD4e9e/eGH/7wh+2xOgA6oHYJoGuuuSZ88MEH4Z577sk6Hpx33nlh8eLFh3VMAKB8VaR9sUMRSbthp73hAOjY0o5lVVVVxdsLDoDyJIAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAApRFA9957b6ioqGhVzjrrrLZeDQAd3HHt8aJnn312ePnll/9vJce1y2oA6MDaJRnSwKmtrW2PlwagRLTLPaANGzaEurq6MHjw4HDDDTeEzZs3H3XZ/fv3h6amplYFgNLX5gE0fPjwMH/+/LB48eLw6KOPhk2bNoWRI0eG3bt3H3H5+vr6UF1d3VL69+/f1psEQBGqSJIkac8V7Nq1KwwcODA89NBD4aabbjpiCygtB6UtICEE0PE1NjaGqqqqo85v994BPXv2DGeccUZoaGg44vzKysqsAFBe2v1zQHv27AkbN24Mffv2be9VAVDOAXT77beHZcuWhXfeeSf87W9/C1deeWXo3LlzuO6669p6VQB0YG1+Ce7dd9/Nwmbnzp3h5JNPDhdffHFYtWpV9jMAHLNOCHmlnRDS3nDAV9e9e/eC6l166aW56yxfvrygTkqlZsCAAbnrnHfeebnrbN26NRRi9erVoVg6IRgLDoAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABE0e5fSAfEs2DBgoLqTZgwIXedGTNm5K6zYcOG3HXSL7g8FoN9pkaPHp27Tl1dXe46Xbp0yV2nubk5d51C19VetIAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAo
jIYNHcT48eNz17n44ovDsfLwww8fs3URwosvvhg6Oi0gAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFwUghgv79++eus2DBgtx1evbsmbtOKXrvvfcKqvePf/wjd51FixblrrNixYrcdTZs2BA6Oi0gAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFwUghgrlz5x6TgUX37t0bCvHGG2/krvPyyy/nrrNw4cLcdXbu3Jm7zr59+0Ihdu/eXVA9/jdaQABEIYAA6BgBtHz58jBhwoRQV1cXKioqwrPPPttqfpIk4Z577gl9+/YN3bp1C2PGjCmJ760AIHIApdeUhw0bFubNm3fE+XPmzAmPPPJIeOyxx8Jrr70WTjzxxDBu3LiCr8ECUJpyd0IYP358Vo4kbf08/PDD4a677gpXXHFF9twTTzwR+vTpk7WUrr322q++xQCUhDa9B7Rp06awffv27LLbQdXV1WH48OFh5cqVR6yzf//+0NTU1KoAUPraNIDS8EmlLZ5DpdMH531efX19FlIHS//+/dtykwAoUtF7wc2aNSs0Nja2lC1btsTeJAA6WgDV1tZmjzt27Gj1fDp9cN7nVVZWhqqqqlYFgNLXpgE0aNCgLGiWLFnS8lx6TyftDTdixIi2XBUA5dYLbs+ePaGhoaFVx4O1a9eGmpqaMGDAgDBjxozwy1/+Mpx++ulZIN19993ZZ4YmTpzY1tsOQDkF0OrVq8Mll1zSMj1z5szscdKkSWH+/PnhzjvvzD4rdPPNN4ddu3aFiy++OCxevDgcf/zxbbvlAHRoFUn64Z0ikl6yS3vDQQw9evTIXWf69Om569x333256xx3XP6xg0eNGhUKsWLFioLqwaHSjmVfdF8/ei84AMqTAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUeQfXhc6gO9+97sF1bvrrrty1xk5cmQoVuedd15B9YyGzbGgBQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAoqhIkiQJRaSpqSlUV1fH3gzaSY8ePXLXeeCBB3LXmTx5cihE586dQynZuXNnQfWGDh2au862bdsKWhelq7GxMVRVVR11vhYQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjiuDirpVwHFl2zZk3uOkOGDAnHyt69e3PXmTt3bu46//73v3PXmTNnTu46vXr1CoXo3bt37joGIyUvLSAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIXBSCnYokWLinZg0UIGFU394Ac/yF3n2WefDcfC9ddfn7vOmDFjjtlAs5CXFhAAUQggADpGAC1fvjxMmDAh1NXVhYqKisMuP9x4443Z84eWyy+/vC23GYByDKD02vqwYcPCvHnzjrpMGjjpl1MdLE899dRX3U4Ayr0Twvjx47PyRSorK0Ntbe1X2S4ASly73ANaunRp9pW+Z555Zpg6dWrYuXPnUZfdv39/aGpqalUAKH1tHkDp5bcnnngiLFmyJDzwwANh2bJlWYupubn5iMvX19eH6urqltK/f/+23iQAyuFzQNdee23Lz+eee24YOnRoOO2007JW0WWXXXbY8rNmzQozZ85smU5bQEIIoPS1ezfswYMHh169eoWGhoaj3i+qqqpqVQAofe0eQO+++252D6hv377tvSoASvkS3J49e1q1ZjZt2hTWrl0bampqsnLfffeFq6++OusFt3HjxnDnnXdmw6+MGzeurbcdgHIKoNWrV4dLLrmkZfrg/ZtJkyaFRx99NKxbty784Q9/CLt27co+rDp27Njwi1/8IrvUBgAFB9Do0aNDkiRHnf+X
v/wl70vShtIOH3m9/fbbBa3r29/+djgWPvjgg9x1fvSjHxW0rmM1sGixGzlyZO46K1asaJdtoXQZCw6AKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAACiNr+Qmrg0bNuSu07Vr14LWdehXqf+vBg0alLvOI488krvOe++9F4pZIfs8/XbhY+XFF188ZuuifGkBARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoDEZaYioqKo7JoKKpBx98MHedJEkKWlepue66647JYKSFDsr64YcfFlQP8tACAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRGIy0xDQ2Nuauc//99xe0rj59+uSuc++99+aus3v37lDMhgwZckwGcv30009z17nzzjtDIQodxBTy0AICIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFFUJEmShCLS1NQUqqurY29Gh3XyySfnrvPOO+8UtK5u3brlrvPWW2/lrvPYY4/lrlPoMXTJJZfkrvOtb30rd53jjss/DvBtt92Wu87cuXNz14G2HBy5qqrqqPO1gACIQgABUPwBVF9fHy644ILQo0eP0Lt37zBx4sSwfv36Vsvs27cvTJs2LZx00kmhe/fu4eqrrw47duxo6+0GoJwCaNmyZVm4rFq1Krz00kvhk08+CWPHjg179+5tWebWW28Nzz//fFi4cGG2/NatW8NVV13VHtsOQAeW607o4sWLW03Pnz8/awmtWbMmjBo1Krvh9Pvf/z48+eST4dJLL82Wefzxx8PXv/71LLQKuVkLQGnq1BZf/1xTU5M9pkGUtorGjBnTssxZZ50VBgwYEFauXHnE19i/f3/W8+3QAkDpKziADhw4EGbMmBEuuuiicM4552TPbd++PXTt2jX07Nmz1bJ9+vTJ5h3tvlLaZfZg6d+/f6GbBEA5BFB6L+jNN98MTz/99FfagFmzZmUtqYNly5YtX+n1AOgY8n8aLoQwffr08MILL4Tly5eHfv36tTxfW1sbPv7447Br165WraC0F1w670gqKyuzAkB5ydUCSgdNSMNn0aJF4ZVXXgmDBg1qNf/8888PXbp0CUuWLGl5Lu2mvXnz5jBixIi222oAyqsFlF52S3u4Pffcc9lngQ7e10nv3aTDsqSPN910U5g5c2bWMSEdguGWW27JwkcPOAAKDqBHH300exw9enSr59Ou1jfeeGP2869+9avQqVOn7AOoaQ+3cePGhd/85jd5VgNAGTAYKdmbhEKk9wHz6ty5cyg16RutvO64447cdX7961/nrgMxGYwUgKIkgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFEbDpmB333137jq333577jrdu3fPXefTTz8Nhfjd736Xu076FSR5vf3227nrQEdjNGwAipIAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgMRkrRu/zyy3PXeeuttwpa1zvvvFNQPeBwBiMFoCgJIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCYKQAtAuDkQJQlAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACoPgDqL6+PlxwwQWhR48eoXfv3mHixIlh/fr1rZYZPXp0qKioaFWmTJnS1tsNQDkF0LJly8K0adPCqlWrwksvvRQ++eSTMHbs2LB3795Wy02ePDls27atpcyZM6ettxuADu64PAsvXry41fT8+fOzltCaNWvCqFGjWp4/4YQTQm1tbdttJQAlp9NX/brVVE1NTavn
FyxYEHr16hXOOeecMGvWrPDRRx8d9TX279+ffQ33oQWAMpAUqLm5Ofn+97+fXHTRRa2e/+1vf5ssXrw4WbduXfLHP/4xOeWUU5Irr7zyqK8ze/bsJN0MRVEUJZRUaWxs/MIcKTiApkyZkgwcODDZsmXLFy63ZMmSbEMaGhqOOH/fvn3ZRh4s6evF3mmKoihKaPcAynUP6KDp06eHF154ISxfvjz069fvC5cdPnx49tjQ0BBOO+20w+ZXVlZmBYDykiuA0hbTLbfcEhYtWhSWLl0aBg0a9KV11q5dmz327du38K0EoLwDKO2C/eSTT4bnnnsu+yzQ9u3bs+erq6tDt27dwsaNG7P53/ve98JJJ50U1q1bF2699dash9zQoUPb63cAoCPKc9/naNf5Hn/88Wz+5s2bk1GjRiU1NTVJZWVlMmTIkOSOO+740uuAh0qXjX3dUlEURQlfuXzZub/i/wdL0Ui7YactKgA6tvSjOlVVVUedbyw4AKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIougBKkiT2JgBwDM7nRRdAu3fvjr0JAByD83lFUmRNjgMHDoStW7eGHj16hIqKilbzmpqaQv/+/cOWLVtCVVVVKFf2w2fsh8/YD5+xH4pnP6SxkoZPXV1d6NTp6O2c40KRSTe2X79+X7hMulPL+QA7yH74jP3wGfvhM/ZDceyH6urqL12m6C7BAVAeBBAAUXSoAKqsrAyzZ8/OHsuZ/fAZ++Ez9sNn7IeOtx+KrhMCAOWhQ7WAACgdAgiAKAQQAFEIIACi6DABNG/evHDqqaeG448/PgwfPjy8/vrrodzce++92egQh5azzjorlLrly5eHCRMmZJ+qTn/nZ599ttX8tB/NPffcE/r27Ru6desWxowZEzZs2BDKbT/ceOONhx0fl19+eSgl9fX14YILLshGSundu3eYOHFiWL9+fatl9u3bF6ZNmxZOOumk0L1793D11VeHHTt2hHLbD6NHjz7seJgyZUooJh0igJ555pkwc+bMrGvhG2+8EYYNGxbGjRsX3n///VBuzj777LBt27aWsmLFilDq9u7dm/3N0zchRzJnzpzwyCOPhMceeyy89tpr4cQTT8yOj/REVE77IZUGzqHHx1NPPRVKybJly7JwWbVqVXjppZfCJ598EsaOHZvtm4NuvfXW8Pzzz4eFCxdmy6dDe1111VWh3PZDavLkya2Oh/R/pagkHcCFF16YTJs2rWW6ubk5qaurS+rr65NyMnv27GTYsGFJOUsP2UWLFrVMHzhwIKmtrU0efPDBlud27dqVVFZWJk899VRSLvshNWnSpOSKK65Iysn777+f7Ytly5a1/O27dOmSLFy4sGWZf/3rX9kyK1euTMplP6S+853vJD/5yU+SYlb0LaCPP/44rFmzJruscuh4cen0ypUrQ7lJLy2ll2AGDx4cbrjhhrB58+ZQzjZt2hS2b9/e6vhIx6BKL9OW4/GxdOnS7JLMmWeeGaZOnRp27twZSlljY2P2WFNTkz2m54q0NXDo8ZBeph4wYEBJHw+Nn9sPBy1YsCD06tUrnHPOOWHWrFnho48+CsWk6AYj/bwPP/wwNDc3hz59+rR6Pp1+6623QjlJT6rz58/PTi5pc/q+++4LI0eODG+++WZ2LbgcpeGTOtLxcXBeuUgvv6WXmgYNGhQ2btwYfvazn4Xx48dnJ97OnTuHUpOOnD9jxoxw0UUXZSfYVPo379q1a+jZs2fZHA8HjrAfUtdff30YOHBg9oZ13bp14ac//Wl2n+jPf/5zKBZFH0D8n/RkctDQoUOzQEoPsD/96U/hpptuirptxHfttde2/Hzuuedmx8hpp52WtYouu+yyUGrSeyDpm69yuA9ayH64+eabWx0PaSed9DhI
35ykx0UxKPpLcGnzMX339vleLOl0bW1tKGfpu7wzzjgjNDQ0hHJ18BhwfBwuvUyb/v+U4vExffr08MILL4RXX3211de3pH/z9LL9rl27yuJ4mH6U/XAk6RvWVDEdD0UfQGlz+vzzzw9Llixp1eRMp0eMGBHK2Z49e7J3M+k7m3KVXm5KTyyHHh/pF3KlveHK/fh49913s3tApXR8pP0v0pPuokWLwiuvvJL9/Q+Vniu6dOnS6nhILzul90pL6XhIvmQ/HMnatWuzx6I6HpIO4Omnn856Nc2fPz/55z//mdx8881Jz549k+3btyfl5LbbbkuWLl2abNq0KfnrX/+ajBkzJunVq1fWA6aU7d69O/n73/+elfSQfeihh7Kf//Of/2Tz77///ux4eO6555J169ZlPcEGDRqU/Pe//03KZT+k826//fasp1d6fLz88svJN77xjeT0009P9u3bl5SKqVOnJtXV1dn/wbZt21rKRx991LLMlClTkgEDBiSvvPJKsnr16mTEiBFZKSVTv2Q/NDQ0JD//+c+z3z89HtL/jcGDByejRo1KikmHCKDU3Llzs4Oqa9euWbfsVatWJeXmmmuuSfr27Zvtg1NOOSWbTg+0Uvfqq69mJ9zPl7Tb8cGu2HfffXfSp0+f7I3KZZddlqxfvz4pp/2QnnjGjh2bnHzyyVk35IEDByaTJ08uuTdpR/r90/L444+3LJO+8fjxj3+cfO1rX0tOOOGE5Morr8xOzuW0HzZv3pyFTU1NTfY/MWTIkOSOO+5IGhsbk2Li6xgAiKLo7wEBUJoEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAIQY/h+8Xbv6MgcW6QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "# Show a random validation image (one the model never trained on).\n",
    "# randrange excludes its upper bound, so idx is always a valid index.\n",
    "idx = random.randrange(x_valid.shape[0])\n",
    "plt.imshow(x_valid[idx].reshape((28, 28)), cmap=\"gray\")\n",
    "\n",
    "model.eval()  # disable dropout for inference\n",
    "with torch.no_grad():\n",
    "    outputs = model(x_valid[idx])\n",
    "predicted = outputs.argmax()  # index of the largest logit = predicted digit\n",
    "print('预测这是：', int(predicted))\n",
    "\n",
    "print('实际上这是：', int(y_valid[idx]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rob",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
